{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"NOTES: Batch data is different each time in keras, which result in slight differences in results.\n",
    "Here is the file tests with BAFNet model. The original code was shown below\"\"\"\n",
    "\n",
    "\"\"\"Bettycxh, \"Bottleneck-Attention-Based-Fusion-Network-for-Sleep-Apnea-Detection,\n",
    "\" GitHub repository, n.d. [Online]. \n",
    "Available: https://github.com/Bettycxh/Bottleneck-Attention-Based-Fusion-Network-for-Sleep-Apnea-Detection\n",
    "\"\"\"\n",
    "import time\n",
    "import pickle\n",
    "import keras\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "from keras.layers import Conv1D, Dense, Dropout, MaxPooling1D,Reshape,multiply,Permute,\\\n",
    "              GlobalAveragePooling1D,BatchNormalization,Flatten,UpSampling1D,Conv1DTranspose,\\\n",
    "                Flatten,  Lambda, Input\n",
    "from keras.models import Model,load_model\n",
    "from keras.regularizers import l2\n",
    "from scipy.interpolate import splev, splrep\n",
    "from keras.activations import sigmoid\n",
    "from keras.callbacks import LearningRateScheduler,ModelCheckpoint\n",
    "from keras.utils import np_utils\n",
    "from IPython.display import SVG,display,HTML\n",
    "from keras.utils.vis_utils import model_to_dot\n",
    "import keras.backend as K\n",
    "from sklearn.metrics import confusion_matrix,f1_score,roc_auc_score\n",
    "import random\n",
    "from sklearn.model_selection import train_test_split\n",
    "import math\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.interpolate import CubicSpline\n",
    "def interpolate_numpy_array(arr, desired_length):\n",
    "    cs = CubicSpline(np.linspace(0, 1, len(arr)), arr)\n",
    "    x_new = np.linspace(0, 1, desired_length)\n",
    "    interpolated_arr = cs(x_new)\n",
    "    return interpolated_arr"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kmu-BHHmm8I_"
   },
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "executionInfo": {
     "elapsed": 507,
     "status": "ok",
     "timestamp": 1693229193098,
     "user": {
      "displayName": "Hiếu Nguyễn Xuân",
      "userId": "09184859202144170734"
     },
     "user_tz": -420
    },
    "id": "IslyMKFQm8JB"
   },
   "outputs": [],
   "source": [
    "base_dir = \"dataset\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0,1\"\n",
    "ir = 3 # interpolate interval\n",
    "before = 2\n",
    "after = 2\n",
    "# normalize\n",
    "scaler = lambda arr: (arr - np.min(arr)) / (np.max(arr) - np.min(arr))\n",
    "def load_data():\n",
    "    tm = np.arange(0, (before + 1 + after) * 60, step=1 / float(ir))\n",
    "\"\"\"\n",
    "We can change the file pkl for each case: T1,T2,T3,T4 in here.\n",
    "\"\"\"\n",
    "    with open(os.path.join(base_dir, \"T_1.pkl\"), 'rb') as f: # read preprocessing result\n",
    "        apnea_ecg = pickle.load(f)\n",
    "    x,x_train,x_val = [],[],[]\n",
    "    o_train, y_train = apnea_ecg[\"o_train\"], apnea_ecg[\"y_train\"]\n",
    "    groups_train = apnea_ecg[\"groups_train\"]\n",
    "    for i in range(len(o_train)):\n",
    "        min_distance_list, max_distance_list, mean_distance_list = o_train[i]\n",
    "\t\t# Curve interpolation\n",
    "        min_distance_list_inter = interpolate_numpy_array(min_distance_list,900)\n",
    "        max_distance_list_inter = interpolate_numpy_array(max_distance_list,900)\n",
    "        mean_distance_list_inter = interpolate_numpy_array(mean_distance_list,900)\n",
    "\"\"\"\n",
    "In this part we design the ablation to test the relationship of each terms: MinDP,MaxDP, MeanDP. By add this value in\n",
    "the list x we can represent all the case in the paper including: M1,M2,M3,M4,M5,M6,M7.\n",
    "MinDP: min_distance_list_inter\n",
    "MaxDP: max_distance_list_inter\n",
    "MeanDP: mean_distance_list_inter\n",
    "\"\"\"\n",
    "        x.append([min_distance_list_inter, max_distance_list_inter])\n",
    "    groups_training,groups_val=[],[]\n",
    "\n",
    "    num=[i for i in range(16713)]\n",
    "    trainlist, vallist,y_train, y_val = train_test_split(num,y_train, test_size=0.3,random_state=42,stratify =y_train)\n",
    "    print()\n",
    "    for i in trainlist:\n",
    "        x_train.append(x[i])\n",
    "        groups_training.append(groups_train[i])\n",
    "    for i in vallist:\n",
    "        x_val.append(x[i])\n",
    "        groups_val.append(groups_train[i])\n",
    "\n",
    "    x_train = np.array(x_train, dtype=\"float32\").transpose((0, 2, 1)) # convert to numpy format\n",
    "    y_train= np.array(y_train, dtype=\"float32\")\n",
    "    x_val = np.array(x_val, dtype=\"float32\").transpose((0, 2, 1)) # convert to numpy format\n",
    "    y_val = np.array(y_val, dtype=\"float32\")\n",
    "\n",
    "    x_test = []\n",
    "    o_test, y_test = apnea_ecg[\"o_test\"], apnea_ecg[\"y_test\"]\n",
    "    groups_test = apnea_ecg[\"groups_test\"]\n",
    "    for i in range(len(o_test)):\n",
    "        min_distance_list, max_distance_list, standard_deviation_distance_list = o_test[i]\n",
    "\t\t# Curve interpolation\n",
    "        min_distance_list_inter = interpolate_numpy_array(min_distance_list,900)\n",
    "        max_distance_list_inter = interpolate_numpy_array(max_distance_list,900)\n",
    "        mean_distance_list_inter = interpolate_numpy_array(mean_distance_list,900)\n",
    "        x_test.append([min_distance_list_inter, max_distance_list_inter])\n",
    "    x_test = np.array(x_test, dtype=\"float32\").transpose((0, 2, 1))\n",
    "    y_test = np.array(y_test, dtype=\"float32\")\n",
    "\n",
    "    return x_train,y_train, groups_training,x_val, y_val, groups_val, x_test, y_test, groups_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 35872,
     "status": "ok",
     "timestamp": 1693229233286,
     "user": {
      "displayName": "Hiếu Nguyễn Xuân",
      "userId": "09184859202144170734"
     },
     "user_tz": -420
    },
    "id": "7pJOzNkzm8JC",
    "outputId": "8e7e9cc5-520c-49c5-e7fe-cae0037c9d4e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "input_shape (11699, 900, 2)\n"
     ]
    }
   ],
   "source": [
    "x_train,y_train, groups_train,x_val, y_val, groups_val, x_test, y_test, groups_test= load_data()\n",
    "y_train = np_utils.to_categorical(y_train, num_classes=2) # Convert to two categories\n",
    "y_val = np_utils.to_categorical(y_val, num_classes=2)\n",
    "y_test = np_utils.to_categorical(y_test, num_classes=2)\n",
    "print('input_shape',x_train.shape)\n",
    "#rri_train: min_distance_list_inter\n",
    "#ampl_train: max_distance_list_inter\n",
    "#We only change the input in the original \n",
    "rri_train=np.expand_dims(x_train[:,:,0],axis=2)\n",
    "ampl_train=np.expand_dims(x_train[:,:,1],axis=2)\n",
    "rri_val=np.expand_dims(x_val[:,:,0],axis=2)\n",
    "ampl_val=np.expand_dims(x_val[:,:,1],axis=2)\n",
    "rri_test=np.expand_dims(x_test[:,:,0],axis=2)\n",
    "ampl_test=np.expand_dims(x_test[:,:,1],axis=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7MADDevGm8JC"
   },
   "source": [
    "# BAFNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "executionInfo": {
     "elapsed": 903,
     "status": "ok",
     "timestamp": 1693229249510,
     "user": {
      "displayName": "Hiếu Nguyễn Xuân",
      "userId": "09184859202144170734"
     },
     "user_tz": -420
    },
    "id": "7-oHdvBkm8JD"
   },
   "outputs": [],
   "source": [
    "class ScaledDotProductAttention(keras.layers.Layer):\n",
    "    def __init__(self,\n",
    "                 return_attention=False,\n",
    "                 history_only=False,\n",
    "                 **kwargs):\n",
    "        super(ScaledDotProductAttention, self).__init__(**kwargs)\n",
    "        self.supports_masking = True\n",
    "        self.return_attention = return_attention\n",
    "        self.history_only = history_only\n",
    "        self.intensity = self.attention = None\n",
    "\n",
    "    def get_config(self):\n",
    "        config = {\n",
    "            'return_attention': self.return_attention,\n",
    "            'history_only': self.history_only,\n",
    "        }\n",
    "        base_config = super(ScaledDotProductAttention, self).get_config()\n",
    "        return dict(list(base_config.items()) + list(config.items()))\n",
    "\n",
    "    def compute_output_shape(self, input_shape):\n",
    "        if isinstance(input_shape, list):\n",
    "            query_shape, key_shape, value_shape = input_shape\n",
    "        else:\n",
    "            query_shape = key_shape = value_shape = input_shape\n",
    "        output_shape = query_shape[:-1] + value_shape[-1:]\n",
    "        if self.return_attention:\n",
    "            attention_shape = query_shape[:2] + (key_shape[1],)\n",
    "            return [output_shape, attention_shape]\n",
    "        return output_shape\n",
    "\n",
    "    def compute_mask(self, inputs, mask=None):\n",
    "        if isinstance(mask, list):\n",
    "            mask = mask[0]\n",
    "        if self.return_attention:\n",
    "            return [mask, None]\n",
    "        return mask\n",
    "\n",
    "    def call(self, inputs, mask=None, **kwargs):\n",
    "        if isinstance(inputs, list):\n",
    "            query, key, value = inputs\n",
    "        else:\n",
    "            query = key = value = inputs\n",
    "        if isinstance(mask, list):\n",
    "            mask = mask[1]\n",
    "        feature_dim = K.shape(query)[-1]\n",
    "        e = K.batch_dot(query, key, axes=2) / K.sqrt(K.cast(feature_dim, dtype=K.floatx()))\n",
    "        if self.history_only:\n",
    "            query_len, key_len = K.shape(query)[1], K.shape(key)[1]\n",
    "            indices = K.expand_dims(K.arange(0, key_len), axis=0)\n",
    "            upper = K.expand_dims(K.arange(0, query_len), axis=-1)\n",
    "            e -= 10000.0 * K.expand_dims(K.cast(indices > upper, K.floatx()), axis=0)\n",
    "        if mask is not None:\n",
    "            e -= 10000.0 * (1.0 - K.cast(K.expand_dims(mask, axis=-2), K.floatx()))\n",
    "        self.intensity = e\n",
    "        e = K.exp(e - K.max(e, axis=-1, keepdims=True))\n",
    "        self.attention = e / K.sum(e, axis=-1, keepdims=True)\n",
    "        v = K.batch_dot(self.attention, value)\n",
    "        if self.return_attention:\n",
    "            return [v, self.attention]\n",
    "        return v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "executionInfo": {
     "elapsed": 1,
     "status": "ok",
     "timestamp": 1693229253281,
     "user": {
      "displayName": "Hiếu Nguyễn Xuân",
      "userId": "09184859202144170734"
     },
     "user_tz": -420
    },
    "id": "3lSL0Vg0m8JD"
   },
   "outputs": [],
   "source": [
    "def create_model(input_shape,weight=1e-3):\n",
    "    inp=Input(shape=input_shape)\n",
    "    input1 =Reshape((900, 1))(inp[:,:,0])\n",
    "    input2 = Reshape((900, 1))(inp[:,:,1])\n",
    "\n",
    "    x1 = Conv1D(16, kernel_size=11, strides=1, padding=\"same\", activation=\"relu\", kernel_initializer=\"he_normal\",\n",
    "                  kernel_regularizer=l2(weight), bias_regularizer=l2(weight))(input1)\n",
    "    x2 = Conv1D(16, kernel_size=11, strides=1, padding=\"same\", activation=\"relu\", kernel_initializer=\"he_normal\",\n",
    "                  kernel_regularizer=l2(weight), bias_regularizer=l2(weight))(input2)\n",
    "\n",
    "    x1 = Conv1D(24, kernel_size=11, strides=2, padding=\"same\", activation=\"relu\", kernel_initializer=\"he_normal\",\n",
    "                  kernel_regularizer=l2(1e-3), bias_regularizer=l2(weight))(x1)\n",
    "    x1 = MaxPooling1D(pool_size=3, padding=\"same\")(x1)\n",
    "    x2 = Conv1D(24, kernel_size=11, strides=2, padding=\"same\", activation=\"relu\", kernel_initializer=\"he_normal\",\n",
    "                  kernel_regularizer=l2(1e-3), bias_regularizer=l2(weight))(x2)\n",
    "    x2 = MaxPooling1D(pool_size=3, padding=\"same\")(x2)\n",
    "    fsn2=keras.layers.concatenate([x1, x2], name=\"fsn2\", axis=-1)\n",
    "\n",
    "    x1 = Conv1D(32 , kernel_size=11, strides=1, padding=\"same\", activation=\"relu\", kernel_initializer=\"he_normal\",\n",
    "                  kernel_regularizer=l2(1e-3), bias_regularizer=l2(weight))(x1)\n",
    "    x1 = MaxPooling1D(pool_size=5, padding=\"same\")(x1)\n",
    "    x2 = Conv1D(32, kernel_size=11, strides=1, padding=\"same\", activation=\"relu\", kernel_initializer=\"he_normal\",\n",
    "                  kernel_regularizer=l2(1e-3), bias_regularizer=l2(weight))(x2)\n",
    "    x2 = MaxPooling1D(pool_size=5, padding=\"same\")(x2)\n",
    "    fsn3=Conv1D(32, kernel_size=11, strides=1, padding=\"same\", activation=\"relu\", kernel_initializer=\"he_normal\",\n",
    "                  kernel_regularizer=l2(1e-3), bias_regularizer=l2(weight))(fsn2)\n",
    "    fsn3=MaxPooling1D(pool_size=5, padding=\"same\")(fsn3)\n",
    "    fsn3=ScaledDotProductAttention()([fsn3,fsn3,fsn3])\n",
    "    x1=ScaledDotProductAttention()([fsn3,x1,x1])\n",
    "    x2=ScaledDotProductAttention()([fsn3,x2,x2])\n",
    "\n",
    "    # concat\n",
    "    concat = keras.layers.concatenate([x1, x2], name=\"Concat_Layer_x1\", axis=-1)\n",
    "\n",
    "    # FCN_1\n",
    "    FCN1 = UpSampling1D(5)(x1)\n",
    "    FCN1 = Conv1DTranspose(24, kernel_size=11, strides=1, padding=\"same\", kernel_initializer=\"he_normal\",\n",
    "                              kernel_regularizer=l2(1e-3), bias_regularizer=l2(weight))(FCN1)\n",
    "    FCN1 = UpSampling1D(3)(FCN1)\n",
    "    FCN1 = Conv1DTranspose(16, kernel_size=11, strides=2, padding=\"same\", kernel_initializer=\"he_normal\",\n",
    "                              kernel_regularizer=l2(1e-3), bias_regularizer=l2(weight))(FCN1)\n",
    "    FCN1 = Conv1DTranspose(1, kernel_size=11, strides=1, padding=\"same\", kernel_initializer=\"he_normal\",\n",
    "                              kernel_regularizer=l2(1e-3), bias_regularizer=l2(weight),name='rri')(FCN1)\n",
    "\n",
    "    # FCN_2\n",
    "    FCN2 = UpSampling1D(5)(x2)\n",
    "    FCN2 = Conv1DTranspose(24, kernel_size=11, strides=1, padding=\"same\", kernel_initializer=\"he_normal\",\n",
    "                              kernel_regularizer=l2(1e-3), bias_regularizer=l2(weight))(FCN2)\n",
    "    FCN2 = UpSampling1D(3)(FCN2)\n",
    "    FCN2 = Conv1DTranspose(16, kernel_size=11, strides=2, padding=\"same\", kernel_initializer=\"he_normal\",\n",
    "                              kernel_regularizer=l2(1e-3), bias_regularizer=l2(weight))(FCN2)\n",
    "    FCN2 = Conv1DTranspose(1, kernel_size=11, strides=1, padding=\"same\", kernel_initializer=\"he_normal\",\n",
    "                              kernel_regularizer=l2(1e-3), bias_regularizer=l2(weight),name='ampl')(FCN2)\n",
    "\n",
    "    #Channel-wise fusion module\n",
    "    squeeze = GlobalAveragePooling1D()(concat)\n",
    "    excitation=Dense(32,activation='relu')(squeeze)\n",
    "    excitation=Dense(64,activation='sigmoid')(excitation)\n",
    "    excitation = Reshape((1, 64))(excitation)\n",
    "    scale = multiply([concat, excitation])\n",
    "\n",
    "    # Classification\n",
    "    x = GlobalAveragePooling1D(name='GAP')(scale)\n",
    "    outputs=Dense(2,activation='softmax',name=\"outputs\")(x)\n",
    "    model = Model(inputs=inp, outputs=[outputs,FCN1,FCN2])\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "f5Y_0rkvm8JE"
   },
   "source": [
    "# Training stage 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "executionInfo": {
     "elapsed": 1,
     "status": "ok",
     "timestamp": 1693229257050,
     "user": {
      "displayName": "Hiếu Nguyễn Xuân",
      "userId": "09184859202144170734"
     },
     "user_tz": -420
    },
    "id": "Cf6jbmt4m8JE"
   },
   "outputs": [],
   "source": [
    "def lr_schedule(epoch, lr):\n",
    "    if epoch > 70 and (epoch - 1) % 10 == 0:\n",
    "        lr *= 0.1\n",
    "    print(\"Learning rate: \", lr)\n",
    "    return lr\n",
    "\n",
    "def plot(history):\n",
    "    \"\"\"Plot performance curve\"\"\"\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n",
    "    print(history)\n",
    "    axes[0].plot(history[\"outputs_loss\"], \"r-\", history[\"val_outputs_loss\"], \"b-\", linewidth=0.5)\n",
    "    axes[0].set_title(\"Loss\")\n",
    "    axes[1].plot(history[\"outputs_accuracy\"], \"r-\", history[\"val_outputs_accuracy\"], \"b-\", linewidth=0.5)\n",
    "    axes[1].set_title(\"Accuracy\")\n",
    "    fig.tight_layout()\n",
    "    fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Please change the filepath of your device to save the result\"\"\"\n",
    "if __name__ == \"__main__\":\n",
    "    model=create_model(x_train.shape[1:])\n",
    "    # compile\n",
    "    model.compile(loss={'outputs': 'binary_crossentropy','rri':'mean_squared_error','ampl':'mean_squared_error'},loss_weights={\n",
    "                  'outputs': 1,'rri': 1,'ampl':1}, optimizer='adam', metrics={'outputs':'accuracy'})\n",
    "    filepath='/content/drive/MyDrive/Final Result/Final_performance/BAFNet/model/min_mean_1.hdf5'\n",
    "    checkpoint = ModelCheckpoint(filepath, monitor='val_outputs_accuracy', verbose=1, save_best_only=True,mode='max')\n",
    "    lr_scheduler = LearningRateScheduler(lr_schedule)\n",
    "    callbacks_list = [lr_scheduler, checkpoint]\n",
    "    time_begin = time.time()\n",
    "    history = model.fit(x_train, [y_train,ampl_train,rri_train], batch_size=128, epochs=100,\n",
    "                        validation_data=(x_val, [y_val,ampl_val,rri_val]),callbacks=callbacks_list)\n",
    "    time_end = time.time()\n",
    "    t = time_end - time_begin\n",
    "    print('time_train:', t)\n",
    "    plot(history.history)\n",
    "\n",
    "#     test\n",
    "#     model= load_model(filepath,custom_objects={'ScaledDotProductAttention':ScaledDotProductAttention})\n",
    "\n",
    "    r= model.evaluate(x_test, [y_test,ampl_test,rri_test])\n",
    "    loss=r[0]\n",
    "    acc=r[-1]\n",
    "    # save prediction score\n",
    "    y_score = model.predict(x_test, batch_size=1024, verbose=1)[0]\n",
    "    roc=roc_auc_score(y_score=y_score,y_true=y_test)\n",
    "    output = pd.DataFrame({\"y_true\": y_test[:, 1], \"y_score\": y_score[:, 1], \"subject\": groups_test})\n",
    "# Export to csv file\n",
    "    output.to_csv(\"/content/drive/MyDrive/Final Result/Final_performance/BAFNet/file_name.csv\", index=False)\n",
    "    y=model.predict(x_test, batch_size=1024, verbose=1)[0]\n",
    "    y_true, y_pred = np.argmax(y_test, axis=-1), np.argmax(y, axis=-1)\n",
    "    C = confusion_matrix(y_true, y_pred, labels=(1, 0))\n",
    "    TP, TN, FP, FN = C[0, 0], C[1, 1], C[1, 0], C[0, 1]\n",
    "    acc, sn, sp = 1. * (TP + TN) / (TP + TN + FP + FN), 1. * TP / (TP + FN), 1. * TN / (TN + FP)\n",
    "    f1=f1_score(y_true, y_pred, average='binary')\n",
    "\n",
    "    print(\"TP:{}, TN:{}, FP:{}, FN:{}, loss{}, acc{}, sn{}, sp{}, f1{}, roc{}\".format(TP, TN, FP, FN,loss, acc, sn, sp, f1, roc))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
